{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<small><i>This notebook was put together by [Jake Vanderplas](http://www.vanderplas.com). Source and license info is on [GitHub](https://github.com/jakevdp/sklearn_tutorial/).</i></small>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Principles of Machine Learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we'll dive into the basic principles of machine learning, and how to\n",
    "use them via the scikit-learn API.\n",
    "\n",
    "After briefly introducing scikit-learn's *Estimator* object, we'll cover **supervised learning**, including *classification* and *regression* problems, and **unsupervised learning**, including *dimensionality reduction* and *clustering* problems."
   ]
  },
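  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "# 'seaborn' was renamed 'seaborn-v0_8' in Matplotlib 3.6; fall back on older versions\n",
    "style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'\n",
    "plt.style.use(style)"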
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Scikit-learn Estimator Object\n",
    "\n",
    "Every algorithm is exposed in scikit-learn via an *Estimator* object. For instance, linear regression is implemented as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LinearRegression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Estimator parameters**: All the parameters of an estimator can be set when it is instantiated, and have suitable default values:"
   ]
  },
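  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `fit_intercept` is one constructor parameter; it defaults to True\n",
    "# (the older `normalize` option was removed in scikit-learn 1.2)\n",
    "model = LinearRegression(fit_intercept=True)\n",
    "print(model.fit_intercept)"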
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Estimated model parameters**: When an estimator is *fit* to data, its parameters are estimated from the data at hand. All the estimated parameters are attributes of the estimator object whose names end with an underscore:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = np.arange(10)\n",
    "y = 2 * x + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(x, y, 'o');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The input data for sklearn is 2D: (samples == 10 x features == 1)\n",
    "X = x[:, np.newaxis]\n",
    "print(X)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fit the model on our data\n",
    "model.fit(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# underscore at the end indicates a fit parameter\n",
    "print(model.coef_)\n",
    "print(model.intercept_)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model found a line with a slope of 2 and an intercept of 1, as we'd expect."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Supervised Learning: Classification and Regression\n",
    "\n",
    "In **Supervised Learning**, we have a dataset consisting of both features and labels.\n",
    "The task is to construct an estimator which is able to predict the label of an object\n",
    "given the set of features. A relatively simple example is predicting the species of\n",
    "iris given a set of measurements of its flower.\n",
    "Some more complicated examples are:\n",
    "\n",
    "- given a multicolor image of an object through a telescope, determine\n",
    "  whether that object is a star, a quasar, or a galaxy.\n",
    "- given a photograph of a person, identify the person in the photo.\n",
    "- given a list of movies a person has watched and their personal ratings\n",
    "  of those movies, recommend a list of movies they would like\n",
    "  (so-called *recommender systems*: a famous example is the [Netflix Prize](http://en.wikipedia.org/wiki/Netflix_prize)).\n",
    "\n",
    "What these tasks have in common is that there are one or more unknown\n",
    "quantities associated with the object which need to be determined from other\n",
    "observed quantities.\n",
    "\n",
    "Supervised learning is further broken down into two categories, **classification** and **regression**.\n",
    "In classification, the label is discrete, while in regression, the label is continuous. For example,\n",
    "in astronomy, the task of determining whether an object is a star, a galaxy, or a quasar is a\n",
    "classification problem: the label is from three distinct categories. On the other hand, we might\n",
    "wish to estimate the age of an object based on such observations: this would be a regression problem,\n",
    "because the label (age) is a continuous quantity."
   ]
  },
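  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small illustration of discrete versus continuous labels, the sketch below prints the targets of two datasets bundled with scikit-learn. The diabetes dataset is not used elsewhere in this notebook; it is an assumption chosen only because its target is a continuous measurement."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_diabetes, load_iris\n",
    "\n",
    "# Classification: the iris labels take only a few discrete values\n",
    "print(np.unique(load_iris().target))\n",
    "\n",
    "# Regression: the diabetes target is a continuous disease-progression measure\n",
    "print(load_diabetes().target[:5])"
   ]
  },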
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification Example\n",
    "K nearest neighbors (kNN) is one of the simplest learning strategies: given a new, unknown observation, look up the reference observations whose features are closest to it and assign the predominant class among them.\n",
    "\n",
    "Let's try it out on our iris classification problem:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import neighbors, datasets\n",
    "\n",
    "iris = datasets.load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "\n",
    "# create the model\n",
    "knn = neighbors.KNeighborsClassifier(n_neighbors=5)\n",
    "\n",
    "# fit the model\n",
    "knn.fit(X, y)\n",
    "\n",
    "# What kind of iris has 3cm x 5cm sepal and 4cm x 2cm petal?\n",
    "# call the \"predict\" method:\n",
    "result = knn.predict([[3, 5, 4, 2],])\n",
    "\n",
    "print(iris.target_names[result])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also do probabilistic predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "knn.predict_proba([[3, 5, 4, 2],])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fig_code import plot_iris_knn\n",
    "plot_iris_knn()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "#### Exercise\n",
    "\n",
    "Use a different estimator on the same problem: ``sklearn.svm.SVC``.\n",
    "\n",
    "*Note that you don't have to know what it is to use it. We're simply trying out the interface here.*\n",
    "\n",
    "*If you finish early, try to create a plot similar to the one above with the SVC estimator.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Regression Example\n",
    "\n",
    "One of the simplest regression problems is fitting a line to data, which we saw above.\n",
    "Scikit-learn also contains more sophisticated regression algorithms."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create some simple data\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "X = np.random.random(size=(20, 1))\n",
    "y = 3 * X.squeeze() + 2 + np.random.randn(20)\n",
    "\n",
    "plt.plot(X.squeeze(), y, 'o');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As above, we can plot a line of best fit:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = LinearRegression()\n",
    "model.fit(X, y)\n",
    "\n",
    "# Plot the data and the model prediction\n",
    "X_fit = np.linspace(0, 1, 100)[:, np.newaxis]\n",
    "y_fit = model.predict(X_fit)\n",
    "\n",
    "plt.plot(X.squeeze(), y, 'o')\n",
    "plt.plot(X_fit.squeeze(), y_fit);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Scikit-learn also has some more sophisticated models, which can respond to finer features in the data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit a Random Forest\n",
    "from sklearn.ensemble import RandomForestRegressor\n",
    "model = RandomForestRegressor()\n",
    "model.fit(X, y)\n",
    "\n",
    "# Plot the data and the model prediction\n",
    "X_fit = np.linspace(0, 1, 100)[:, np.newaxis]\n",
    "y_fit = model.predict(X_fit)\n",
    "\n",
    "plt.plot(X.squeeze(), y, 'o')\n",
    "plt.plot(X_fit.squeeze(), y_fit);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Whether either of these is a \"good\" fit or not depends on a number of things; we'll discuss details of how to choose a model later in the tutorial."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "#### Exercise\n",
    "\n",
    "Explore the ``RandomForestRegressor`` object using IPython's help features (i.e. put a question mark after the object).\n",
    "What arguments are available to ``RandomForestRegressor``?\n",
    "How does the above plot change if you change these arguments?\n",
    "\n",
    "These class-level arguments are known as *hyperparameters*, and we will discuss how to select them later, in the model validation section.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Unsupervised Learning: Dimensionality Reduction and Clustering\n",
    "\n",
    "**Unsupervised Learning** addresses a different sort of problem. Here the data has no labels,\n",
    "and we are interested in finding similarities between the objects in question. In a sense,\n",
    "you can think of unsupervised learning as a means of discovering labels from the data itself.\n",
    "Unsupervised learning comprises tasks such as *dimensionality reduction*, *clustering*, and\n",
    "*density estimation*. For example, in the iris data discussed above, we can use unsupervised\n",
    "methods to determine combinations of the measurements which best display the structure of the\n",
    "data. As we'll see below, such a projection of the data can be used to visualize the\n",
    "four-dimensional dataset in two dimensions. Some more involved unsupervised learning problems are:\n",
    "\n",
    "- given detailed observations of distant galaxies, determine which features or combinations of\n",
    "  features best summarize the information.\n",
    "- given a mixture of two sound sources (for example, a person talking over some music),\n",
    "  separate the two (this is called the [blind source separation](http://en.wikipedia.org/wiki/Blind_signal_separation) problem).\n",
    "- given a video, isolate a moving object and categorize it in relation to other moving objects which have been seen.\n",
    "\n",
    "Sometimes the two may even be combined: e.g. unsupervised learning can be used to find useful\n",
    "features in heterogeneous data, and then these features can be used within a supervised\n",
    "framework."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dimensionality Reduction: PCA\n",
    "\n",
    "Principal Component Analysis (PCA) is a dimensionality reduction technique that finds the combinations of variables that explain the most variance.\n",
    "\n",
    "Consider the iris dataset. It cannot be visualized in a single 2D plot, as it has 4 features. We are going to extract 2 combinations of sepal and petal dimensions to visualize it:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X, y = iris.data, iris.target\n",
    "\n",
    "from sklearn.decomposition import PCA\n",
    "# a float in (0, 1) keeps just enough components to explain that fraction of the variance\n",
    "pca = PCA(n_components=0.95)\n",
    "pca.fit(X)\n",
    "X_reduced = pca.transform(X)\n",
    "print(\"Reduced dataset shape:\", X_reduced.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y,\n",
    "           cmap='RdYlBu')\n",
    "\n",
    "print(\"Meaning of the 2 components:\")\n",
    "for component in pca.components_:\n",
    "    print(\" + \".join(\"%.3f x %s\" % (value, name)\n",
    "                     for value, name in zip(component,\n",
    "                                            iris.feature_names)))"
   ]
  },
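  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make *explaining the most variance* concrete, the sketch below inspects the `explained_variance_ratio_` attribute of a fitted PCA. It refits on the iris data under hypothetical `demo_` names, chosen only so that the cell runs on its own and does not overwrite the variables defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "demo_X, _ = load_iris(return_X_y=True)\n",
    "demo_pca = PCA().fit(demo_X)  # no n_components given: keep all four components\n",
    "\n",
    "# Fraction of the total variance captured by each component, largest first\n",
    "print(demo_pca.explained_variance_ratio_)\n",
    "# Running total: how much variance the first k components retain\n",
    "print(np.cumsum(demo_pca.explained_variance_ratio_))"
   ]
  },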
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clustering: K-means\n",
    "\n",
    "Clustering groups together observations that are homogeneous with respect to a given criterion, finding *clusters* in the data.\n",
    "\n",
    "Note that these clusters will uncover relevant hidden structure in the data only if the criterion used highlights it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "k_means = KMeans(n_clusters=3, random_state=0) # Fixing the RNG in kmeans\n",
    "k_means.fit(X)\n",
    "y_pred = k_means.predict(X)\n",
    "\n",
    "plt.scatter(X_reduced[:, 0], X_reduced[:, 1], c=y_pred,\n",
    "           cmap='RdYlBu');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Recap: Scikit-learn's estimator interface\n",
    "\n",
    "Scikit-learn strives to have a uniform interface across all methods,\n",
    "and we'll see examples of these below. Given a scikit-learn *estimator*\n",
    "object named `model`, the following methods are available:\n",
    "\n",
    "- Available in **all Estimators**\n",
    "  + `model.fit()` : fit training data. For supervised learning applications,\n",
    "    this accepts two arguments: the data `X` and the labels `y` (e.g. `model.fit(X, y)`).\n",
    "    For unsupervised learning applications, this accepts only a single argument,\n",
    "    the data `X` (e.g. `model.fit(X)`).\n",
    "- Available in **supervised estimators**\n",
    "  + `model.predict()` : given a trained model, predict the label of a new set of data.\n",
    "    This method accepts one argument, the new data `X_new` (e.g. `model.predict(X_new)`),\n",
    "    and returns the learned label for each object in the array.\n",
    "  + `model.predict_proba()` : For classification problems, some estimators also provide\n",
    "    this method, which returns the probability that a new observation has each categorical label.\n",
    "    In this case, the label with the highest probability is returned by `model.predict()`.\n",
    "  + `model.score()` : for classification or regression problems, most estimators implement\n",
    "    a score method. Classifiers report mean accuracy (between 0 and 1), while regressors report\n",
    "    the $R^2$ coefficient of determination (at most 1, and negative for a fit worse than\n",
    "    predicting the mean). In both cases a larger score indicates a better fit.\n",
    "- Available in **unsupervised estimators**\n",
    "  + `model.predict()` : predict labels in clustering algorithms.\n",
    "  + `model.transform()` : given an unsupervised model, transform new data into the new basis.\n",
    "    This also accepts one argument `X_new`, and returns the new representation of the data based\n",
    "    on the unsupervised model.\n",
    "  + `model.fit_transform()` : some estimators implement this method,\n",
    "    which more efficiently performs a fit and a transform on the same input data."
   ]
  },
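  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal sketch of this shared interface, not a definitive recipe. It assumes the built-in iris data and uses `KNeighborsClassifier` and `PCA` purely for illustration; the `demo_` variable names are hypothetical and chosen so that nothing defined earlier in the notebook is overwritten."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "demo_X, demo_y = load_iris(return_X_y=True)\n",
    "\n",
    "# Supervised estimator: fit(X, y), then predict / predict_proba / score\n",
    "demo_clf = KNeighborsClassifier(n_neighbors=5).fit(demo_X, demo_y)\n",
    "print(demo_clf.predict(demo_X[:3]))        # one predicted label per row\n",
    "print(demo_clf.predict_proba(demo_X[:3]))  # one probability per class, per row\n",
    "print(demo_clf.score(demo_X, demo_y))      # mean accuracy on the data passed in\n",
    "\n",
    "# Unsupervised estimator: fit(X) only, then transform (or fit_transform in one step)\n",
    "demo_pca = PCA(n_components=2)\n",
    "demo_X2 = demo_pca.fit_transform(demo_X)\n",
    "print(demo_X2.shape)  # (150, 2)"
   ]
  },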
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Validation\n",
    "\n",
    "An important piece of machine learning is **model validation**: that is, determining how well your model will generalize from the training data to future unlabeled data. Let's look at an example using the *nearest neighbor classifier*. This is a very simple classifier: it stores all training data and, for any unknown sample, returns the label of the closest training point.\n",
    "\n",
    "With the iris data, it very easily returns the correct prediction for each of the input points:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "X, y = iris.data, iris.target\n",
    "clf = KNeighborsClassifier(n_neighbors=1)\n",
    "clf.fit(X, y)\n",
    "y_pred = clf.predict(X)\n",
    "print(np.all(y == y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A more useful way to look at the results is to view the **confusion matrix**, the matrix showing how often each true label (rows) is assigned each predicted label (columns):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "print(confusion_matrix(y, y_pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For each class, all 50 training samples are correctly identified. But this **does not mean that our model is perfect!** A perfect score on the training data says little about how well the model generalizes to new data. We can estimate that by splitting our data into a *training set* and a *testing set*. Scikit-learn contains some convenient routines to do this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "Xtrain, Xtest, ytrain, ytest = train_test_split(X, y)\n",
    "clf.fit(Xtrain, ytrain)\n",
    "ypred = clf.predict(Xtest)\n",
    "print(confusion_matrix(ytest, ypred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This paints a better picture of the true performance of our classifier: depending on the random split, there is typically some confusion between the second and third species, which we might anticipate given what we've seen of the data above.\n",
    "\n",
    "This is why it's **extremely important** to use a train/test split when evaluating your models.  We'll go into more depth on model evaluation later in this tutorial."
   ]
  },
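  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough sketch of why the split matters, the cell below scores the same one-nearest-neighbor classifier on the data it was fit to and on held-out data. The `random_state=0` seed and the `demo_` names are assumptions made for illustration; the held-out score will vary with the split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "demo_X, demo_y = load_iris(return_X_y=True)\n",
    "demo_Xtrain, demo_Xtest, demo_ytrain, demo_ytest = train_test_split(\n",
    "    demo_X, demo_y, random_state=0)\n",
    "\n",
    "demo_clf = KNeighborsClassifier(n_neighbors=1).fit(demo_Xtrain, demo_ytrain)\n",
    "\n",
    "# Each training point is its own nearest neighbor, so this score is trivially high\n",
    "print('train accuracy:', demo_clf.score(demo_Xtrain, demo_ytrain))\n",
    "# Accuracy on unseen data is the more honest estimate of generalization\n",
    "print('test accuracy: ', demo_clf.score(demo_Xtest, demo_ytest))"
   ]
  },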
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Flow Chart: How to Choose your Estimator\n",
    "\n",
    "This is a flow chart created by scikit-learn super-contributor [Andreas Mueller](https://github.com/amueller) which gives a nice summary of which algorithms to choose in various situations. Keep it around as a handy reference!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Image\n",
    "Image(\"http://scikit-learn.org/dev/_static/ml_map.png\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Original source on the [scikit-learn website](http://scikit-learn.org/stable/tutorial/machine_learning_map/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quick Application: Optical Character Recognition"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To demonstrate the above principles on a more interesting problem, let's consider OCR (Optical Character Recognition), that is, recognizing hand-written digits.\n",
    "In the wild, this problem involves both locating and identifying characters in an image. Here we'll take a shortcut and use scikit-learn's set of pre-formatted digits, which is built into the library."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading and visualizing the digits data\n",
    "\n",
    "We'll use scikit-learn's data access interface and take a look at this data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn import datasets\n",
    "digits = datasets.load_digits()\n",
    "digits.images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's plot a few of these:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(10, 10, figsize=(8, 8))\n",
    "fig.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    ax.imshow(digits.images[i], cmap='binary', interpolation='nearest')\n",
    "    ax.text(0.05, 0.05, str(digits.target[i]),\n",
    "            transform=ax.transAxes, color='green')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here the data is simply each pixel value within an 8x8 grid:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The images themselves\n",
    "print(digits.images.shape)\n",
    "print(digits.images[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The data for use in our algorithms\n",
    "print(digits.data.shape)\n",
    "print(digits.data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The target label\n",
    "print(digits.target)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So our data have 1797 samples in 64 dimensions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Unsupervised Learning: Dimensionality Reduction\n",
    "\n",
    "We'd like to visualize our points within the 64-dimensional parameter space, but it's difficult to plot points in 64 dimensions!\n",
    "Instead we'll reduce the dimensions to 2, using an unsupervised method.\n",
    "Here, we'll make use of a manifold learning algorithm called *Isomap*, and transform the data to two dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.manifold import Isomap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iso = Isomap(n_components=2)\n",
    "data_projected = iso.fit_transform(digits.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_projected.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(data_projected[:, 0], data_projected[:, 1], c=digits.target,\n",
    "            edgecolor='none', alpha=0.5, cmap=plt.get_cmap('nipy_spectral', 10));\n",
    "plt.colorbar(label='digit label', ticks=range(10))\n",
    "plt.clim(-0.5, 9.5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see here that the digits are fairly well-separated in the parameter space; this tells us that a supervised classification algorithm should perform fairly well. Let's give it a try."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Classification on Digits\n",
    "\n",
    "Let's try a classification task on the digits. The first thing we'll want to do is split the digits into a training set and a testing set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "Xtrain, Xtest, ytrain, ytest = train_test_split(digits.data, digits.target,\n",
    "                                                random_state=2)\n",
    "print(Xtrain.shape, Xtest.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's use a simple logistic regression which (despite its confusing name) is a classification algorithm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "clf = LogisticRegression(penalty='l2')\n",
    "clf.fit(Xtrain, ytrain)\n",
    "ypred = clf.predict(Xtest)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can check our classification accuracy by comparing the true values of the test set to the predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "accuracy_score(ytest, ypred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This single number doesn't tell us **where** we've gone wrong. One nice way to find out is to use the *confusion matrix*:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "print(confusion_matrix(ytest, ypred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# log1p(x) = log(1 + x) keeps empty cells finite, whereas log(0) would give -inf\n",
    "plt.imshow(np.log1p(confusion_matrix(ytest, ypred)),\n",
    "           cmap='Blues', interpolation='nearest')\n",
    "plt.grid(False)\n",
    "plt.ylabel('true')\n",
    "plt.xlabel('predicted');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We might also take a look at some of the test images along with their predicted labels. We'll make the incorrect labels red:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(10, 10, figsize=(8, 8))\n",
    "fig.subplots_adjust(hspace=0.1, wspace=0.1)\n",
    "\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    ax.imshow(Xtest[i].reshape(8, 8), cmap='binary')\n",
    "    ax.text(0.05, 0.05, str(ypred[i]),\n",
    "            transform=ax.transAxes,\n",
    "            color='green' if (ytest[i] == ypred[i]) else 'red')\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interesting thing is that even with this simple logistic regression algorithm, many of the mislabeled cases are ones that we ourselves might get wrong!\n",
    "\n",
    "There are many ways to improve this classifier, but we're out of time here. To go further, we could use a more sophisticated model, use cross-validation, or apply other techniques.\n",
    "We'll cover some of these topics later in the tutorial."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
